{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7af2c24d",
   "metadata": {},
   "source": [
    "# 数据预处理\n",
    "- 将原始txt处理成每行的json格式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0f47d4b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1de41397",
   "metadata": {},
   "outputs": [],
   "source": [
    "#统一全角转半角\n",
    "def strQ2B(ustring):\n",
    "    cur_list = []\n",
    "    for s in ustring:\n",
    "        rstring = \"\"\n",
    "        for uchar in s:\n",
    "            inside_code = ord(uchar)\n",
    "            if inside_code == 12288:  # 全角空格直接转换\n",
    "                inside_code = 32\n",
    "            elif (inside_code >= 65281 and inside_code <= 65374):  # 全角字符（除空格）根据关系转化\n",
    "                inside_code -= 65248\n",
    "            rstring += chr(inside_code)\n",
    "        cur_list.append(rstring)\n",
    "    return ''.join(cur_list)\n",
    "\n",
    "\n",
    "#转换特殊字符词组及标点符号\n",
    "trans_punctuations = {'don\\'t':\"do not\",\n",
    "                      '\"':'',\n",
    "                      ';':''\n",
    "                      }\n",
    "def process_data(strs):\n",
    "    for key in trans_punctuations:\n",
    "        strs = strs.replace(key, trans_punctuations[key])\n",
    "    return strQ2B(strs)\n",
    "\n",
    "\n",
    "# 读取原始数据\n",
    "raw_data = []\n",
    "with open('Andersen Fairy Tales.txt', 'r') as f:\n",
    "    for x in f:\n",
    "        x = x.strip().lower()\n",
    "        if x: raw_data.append(process_data(x))\n",
    "\n",
    "#保留长度大于1的句子\n",
    "raw_data = [x for x in raw_data if len(x.split(' '))>1]\n",
    "\n",
    "\n",
    "#保存数据\n",
    "with open(\"./corpus.json\",\"w\") as f:\n",
    "    json.dump(raw_data, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afe8b1ad",
   "metadata": {},
   "source": [
    "# 模型构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b3ba03b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import  Counter\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.functional import cross_entropy\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80f708d1",
   "metadata": {},
   "source": [
    "## 配置文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f7fb4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 配置文件\n",
    "config ={\n",
    "    \"elmo\": {\n",
    "            \"activation\": \"relu\",\n",
    "            \"filters\": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]],\n",
    "            \"n_highway\": 2, \n",
    "            \"word_dim\": 300,\n",
    "            \"char_dim\": 50,\n",
    "            \"max_char_token\": 50,\n",
    "            \"min_count\":5,\n",
    "            \"max_length\":256,\n",
    "            \"output_dim\":150,\n",
    "            \"units\":256,\n",
    "            \"n_layers\":2,\n",
    "        },\n",
    "    \"batch_size\":32,\n",
    "    \"epochs\":50,\n",
    "    \"lr\":0.00001,\n",
    "}\n",
    "\n",
    "# 保存路径\n",
    "model_save_path=\"./elmo_model\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c908fc59",
   "metadata": {},
   "source": [
    "## 数据集构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "867a490f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取语料\n",
    "with open(\"./corpus.json\") as f:\n",
    "    corpus = json.load(f)\n",
    "    corpus = corpus[:1000] # 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "853721b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "# 检测是否有可用GPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "print('device: ' + str(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1efec278",
   "metadata": {},
   "outputs": [],
   "source": [
    "#分词器\n",
    "class Tokenizer:\n",
    "    def __init__(self, word2id,ch2id):\n",
    "        self.word2id = word2id\n",
    "        self.ch2id = ch2id\n",
    "        self.id2word = {i: word for word, i in word2id.items()}\n",
    "        self.id2ch = {i: char for char, i in ch2id.items()}\n",
    "    \n",
    "    def tokenize(self,text,max_length=512,max_char=50):\n",
    "        oov_id, pad_id = self.word2id.get(\"<oov>\"), self.word2id.get(\"<pad>\")\n",
    "        w = torch.LongTensor(max_length).fill_(pad_id)\n",
    "        words = text.lower().split()\n",
    "        for i, wi in enumerate(words[:max_length]):\n",
    "            w[i] = self.word2id.get(wi, oov_id)\n",
    "        oov_id, pad_id = self.ch2id.get(\"<oov>\"), self.ch2id.get(\"<pad>\")\n",
    "        c = torch.LongTensor(max_length,max_char).fill_(pad_id)\n",
    "        for i, wi in enumerate(words[:max_length]):\n",
    "            for j,wij in enumerate(wi[:max_char]):\n",
    "                c[i][j]=self.ch2id.get(wij, oov_id)\n",
    "        return w , c , len(words[:max_length])\n",
    "\n",
    "    def save(self,path):\n",
    "        try:\n",
    "            os.mkdir(path)\n",
    "        except:\n",
    "            pass\n",
    "        tok ={\n",
    "            \"word2id\":self.word2id,\n",
    "            \"ch2id\":self.ch2id\n",
    "        }\n",
    "        with open(f\"{path}/tokenizer.json\",\"w\") as f:\n",
    "            json.dump(tok,f,indent=4)\n",
    "            \n",
    "            \n",
    "# 从语料中构建\n",
    "def from_corpus(corpus,min_count=5):\n",
    "    word_count = Counter()\n",
    "    for sentence in corpus:\n",
    "        word_count.update(sentence.split())\n",
    "    word_count = list(word_count.items())\n",
    "    word_count.sort(key=lambda x: x[1], reverse=True)\n",
    "    for i, (word, count) in enumerate(word_count):\n",
    "        if count < min_count:\n",
    "            break\n",
    "    vocab = word_count[:i]\n",
    "    vocab = [v[0] for v in vocab]\n",
    "    word_lexicon = {}\n",
    "    for special_word in ['<oov>', '<pad>']:\n",
    "        if special_word not in word_lexicon:\n",
    "            word_lexicon[special_word] = len(word_lexicon)\n",
    "    for word in vocab:\n",
    "        if word not in word_lexicon:\n",
    "            word_lexicon[word] = len(word_lexicon)\n",
    "    char_lexicon = {}\n",
    "    for special_char in ['<oov>', '<pad>']:\n",
    "        if special_char not in char_lexicon:\n",
    "            char_lexicon[special_char] = len(char_lexicon)\n",
    "    for sentence in corpus:\n",
    "        for word in sentence.split():\n",
    "            for ch in word:\n",
    "                if ch not in char_lexicon:\n",
    "                    char_lexicon[ch] = len(char_lexicon)\n",
    "    return Tokenizer(word_lexicon,char_lexicon)\n",
    "\n",
    "\n",
    "# 从checkpoint中构建\n",
    "def from_file(path):\n",
    "    with open(f\"{path}/tokenizer.json\") as f:\n",
    "        d = json.load(f)\n",
    "    return Tokenizer(d[\"word2id\"],d[\"ch2id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1f77f6f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化分词器\n",
    "tokenizer = from_corpus(corpus, config[\"elmo\"][\"min_count\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a5c60f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ELMO数据集生成器\n",
    "class ELMoDataSet(Dataset):\n",
    "    def __init__(self,corpus,tokenizer):\n",
    "        self.corpus=corpus\n",
    "        self.tokenizer=tokenizer\n",
    "        \n",
    "    def __getitem__(self, idx):\n",
    "        text = self.corpus[idx]\n",
    "        w,c,i= self.tokenizer.tokenize(text,max_length=config[\"elmo\"][\"max_length\"],max_char=config[\"elmo\"][\"max_char_token\"])\n",
    "        return w,c,i\n",
    "   \n",
    "    def __len__(self):\n",
    "        return len(self.corpus)\n",
    "\n",
    "# 初始化数据集生成器\n",
    "data = ELMoDataSet(corpus,tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b24ad733",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化Pytorch框架的数据生成器\n",
    "data_loader = DataLoader(data, batch_size=config[\"batch_size\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a2c87cc",
   "metadata": {},
   "source": [
    "## 模型初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "414187bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Based upon https://gist.github.com/Redchards/65f1a6f758a1a5c5efb56f83933c3f6e\n",
    "# Original Paper https://arxiv.org/abs/1505.00387\n",
    "# 我们用残差网络替代HighWay\n",
    "class HighWay(nn.Module):\n",
    "    def __init__(self, input_dim, num_layers=1,activation= nn.functional.relu):\n",
    "        super(HighWay, self).__init__()\n",
    "        self._input_dim = input_dim\n",
    "        self._layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])\n",
    "        self._activation = activation\n",
    "        for layer in self._layers:\n",
    "            layer.bias[input_dim:].data.fill_(1)\n",
    "            \n",
    "    def forward(self, inputs):\n",
    "        current_input = inputs\n",
    "        for layer in self._layers:\n",
    "            projected_input = layer(current_input)\n",
    "            linear_part = current_input\n",
    "            nonlinear_part = projected_input[:, (0 * self._input_dim):(1 * self._input_dim)]\n",
    "            gate = projected_input[:, (1 * self._input_dim):(2 * self._input_dim)]\n",
    "            nonlinear_part = self._activation(nonlinear_part)\n",
    "            gate = torch.sigmoid(gate)\n",
    "            current_input = gate * linear_part + (1 - gate) * nonlinear_part\n",
    "        return current_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "90cf8e3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ELMo(nn.Module):\n",
    "    def __init__(self,tokenizer,config):\n",
    "        super(ELMo, self).__init__()\n",
    "        self.config=config\n",
    "        self.tokenizer = tokenizer\n",
    "        self.word_embedding = nn.Embedding(len(tokenizer.word2id),config[\"elmo\"][\"word_dim\"],padding_idx=tokenizer.word2id.get(\"<pad>\"))\n",
    "        self.char_embedding = nn.Embedding(len(tokenizer.ch2id),config[\"elmo\"][\"char_dim\"],padding_idx=tokenizer.ch2id.get(\"<pad>\"))\n",
    "        self.output_dim = config[\"elmo\"][\"output_dim\"]\n",
    "        activation = config[\"elmo\"][\"activation\"]\n",
    "        if activation==\"relu\":\n",
    "            self.act = nn.ReLU()\n",
    "        elif activation==\"tanh\":\n",
    "            self.act=nn.Tanh()\n",
    "        self.emb_dim = config[\"elmo\"][\"word_dim\"]\n",
    "        self.convolutions = []\n",
    "        filters = config[\"elmo\"][\"filters\"]\n",
    "        char_dim = config[\"elmo\"][\"char_dim\"]\n",
    "        for i, (width, num) in enumerate(filters):\n",
    "            conv = nn.Conv1d(in_channels=char_dim,\n",
    "                             out_channels=num,\n",
    "                             kernel_size=width,\n",
    "                             bias=True\n",
    "                             )\n",
    "            self.convolutions.append(conv)\n",
    "        self.convolutions = nn.ModuleList(self.convolutions)\n",
    "        self.n_filters = sum(f[1] for f in filters)\n",
    "        self.n_highway = config[\"elmo\"][\"n_highway\"]\n",
    "        self.highways = HighWay(self.n_filters, self.n_highway, activation=self.act)\n",
    "        self.emb_dim += self.n_filters\n",
    "        self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True)\n",
    "        self.f=[nn.LSTM(input_size = config[\"elmo\"][\"output_dim\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True)]\n",
    "        self.b=[nn.LSTM(input_size = config[\"elmo\"][\"output_dim\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True)]\n",
    "        for _ in range(config[\"elmo\"][\"n_layers\"]-1):\n",
    "            self.f.append(nn.LSTM(input_size = config[\"elmo\"][\"units\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True))\n",
    "            self.b.append(nn.LSTM(input_size = config[\"elmo\"][\"units\"], hidden_size = config[\"elmo\"][\"units\"], batch_first=True))\n",
    "        self.f = nn.ModuleList(self.f)\n",
    "        self.b = nn.ModuleList(self.b)\n",
    "        self.ln = nn.Linear(in_features=config[\"elmo\"][\"units\"], out_features=len(tokenizer.word2id))\n",
    "        \n",
    "    def forward(self, word_inp, chars_inp):\n",
    "        embs = []\n",
    "        batch_size, seq_len = word_inp.size(0), word_inp.size(1)\n",
    "        word_emb = self.word_embedding(Variable(word_inp))\n",
    "        embs.append(word_emb)\n",
    "        chars_inp = chars_inp.view(batch_size * seq_len, -1)\n",
    "        char_emb = self.char_embedding(Variable(chars_inp))\n",
    "        char_emb = char_emb.transpose(1, 2)\n",
    "        convs = []\n",
    "        for i in range(len(self.convolutions)):\n",
    "            convolved = self.convolutions[i](char_emb)\n",
    "            convolved, _ = torch.max(convolved, dim=-1)\n",
    "            convolved = self.act(convolved)\n",
    "            convs.append(convolved)\n",
    "        char_emb = torch.cat(convs, dim=-1)\n",
    "        char_emb = self.highways(char_emb)\n",
    "        embs.append(char_emb.view(batch_size, -1, self.n_filters))\n",
    "        token_embedding = torch.cat(embs, dim=2)\n",
    "        embeddings = self.projection(token_embedding)\n",
    "        fs = [embeddings]         \n",
    "        bs = [embeddings]\n",
    "        for fl,bl in zip(self.f,self.b):\n",
    "            o_f,_ = fl(fs[-1])\n",
    "            fs.append(o_f)\n",
    "            o_b,_ = bl(torch.flip(bs[-1],dims=[1,]))\n",
    "            bs.append(torch.flip(o_b,dims=(1,)))\n",
    "        return fs,bs\n",
    "    \n",
    "    def save_model(self,path):\n",
    "        try:\n",
    "            os.mkdir(path)\n",
    "        except:\n",
    "            pass\n",
    "        torch.save(self.state_dict(),f'{path}/model.pt')\n",
    "        with open(f\"{path}/config.json\",\"w\") as f:\n",
    "            json.dump(self.config,f,indent=4)\n",
    "        self.tokenizer.save(path)\n",
    "        \n",
    "    @classmethod\n",
    "    def from_checkpoint(cls,path,device):\n",
    "        with open(f\"{path}/config.json\") as f:\n",
    "            config = json.load(f)\n",
    "        tokenizer = Tokenizer.from_file(path)\n",
    "        model = cls(tokenizer,config)\n",
    "        model.load_state_dict(torch.load(f'{path}/model.pt',map_location=device))\n",
    "        return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4805db12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ELMo(\n",
       "  (word_embedding): Embedding(1122, 300, padding_idx=1)\n",
       "  (char_embedding): Embedding(39, 50, padding_idx=1)\n",
       "  (act): ReLU()\n",
       "  (convolutions): ModuleList(\n",
       "    (0): Conv1d(50, 32, kernel_size=(1,), stride=(1,))\n",
       "    (1): Conv1d(50, 32, kernel_size=(2,), stride=(1,))\n",
       "    (2): Conv1d(50, 64, kernel_size=(3,), stride=(1,))\n",
       "    (3): Conv1d(50, 128, kernel_size=(4,), stride=(1,))\n",
       "    (4): Conv1d(50, 256, kernel_size=(5,), stride=(1,))\n",
       "    (5): Conv1d(50, 512, kernel_size=(6,), stride=(1,))\n",
       "    (6): Conv1d(50, 1024, kernel_size=(7,), stride=(1,))\n",
       "  )\n",
       "  (highways): HighWay(\n",
       "    (_layers): ModuleList(\n",
       "      (0): Linear(in_features=2048, out_features=4096, bias=True)\n",
       "      (1): Linear(in_features=2048, out_features=4096, bias=True)\n",
       "    )\n",
       "    (_activation): ReLU()\n",
       "  )\n",
       "  (projection): Linear(in_features=2348, out_features=150, bias=True)\n",
       "  (f): ModuleList(\n",
       "    (0): LSTM(150, 256, batch_first=True)\n",
       "    (1): LSTM(256, 256, batch_first=True)\n",
       "  )\n",
       "  (b): ModuleList(\n",
       "    (0): LSTM(150, 256, batch_first=True)\n",
       "    (1): LSTM(256, 256, batch_first=True)\n",
       "  )\n",
       "  (ln): Linear(in_features=256, out_features=1122, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = ELMo(tokenizer,config)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8ad1d0b",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b1277113",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8149a2ccb4dd408398918be4c046df51",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 3546.50244140625\n",
      "Epoch: 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "269103df89af41c3920ea3527295ae89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 3521.490478515625\n",
      "Epoch: 3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f7846e7af9f4d43a5745954bcf2fb5e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 3480.632568359375\n",
      "Epoch: 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2046b6ecbc0a41f9a5b1c824bd1820eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 3414.54736328125\n",
      "Epoch: 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6179a63d480c492a844270e8dfcf58da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 3278.797607421875\n",
      "Epoch: 6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "74b18065f2c14dc9bcc4280ad081f93c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 2848.858154296875\n",
      "Epoch: 7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0f6bf57a60145718d24a6a4a3cded16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 2211.5849609375\n",
      "Epoch: 8\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f37146aa0032426bb6dbdabf255860f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1741.710205078125\n",
      "Epoch: 9\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11e92d641d3b440485be364d32240165",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1446.1951904296875\n",
      "Epoch: 10\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4fa5675d68a45e1ac500a0cb0784e85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1268.07275390625\n",
      "Epoch: 11\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79c6f8166bd34d918756ac2ad1c4a9d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1163.1085205078125\n",
      "Epoch: 12\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4f547d4a674441e81359db4de5c6931",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1097.5172119140625\n",
      "Epoch: 13\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b3a3296c1003467db97931acab637c4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1054.2308349609375\n",
      "Epoch: 14\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "25e85c000be74db8b460704fbc24f8a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 1023.152587890625\n",
      "Epoch: 15\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9be97e8d1f8f43e3a8a922719ae0b0d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 999.1707763671875\n",
      "Epoch: 16\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3cefb96da62646f08b46a0423fa872cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 980.1627197265625\n",
      "Epoch: 17\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d3ce85d29c448dfbd6325d6a7eba605",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 964.4329833984375\n",
      "Epoch: 18\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa2996b3ce85413399e426be07d72e7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 951.1318969726562\n",
      "Epoch: 19\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0293ff98543842ce8e7db4e92e507c04",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "total_loss: 939.527587890625\n",
      "Epoch: 20\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "346fcab8245c4457b6fc23382377a3cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-18-fdfa96df6270>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     21\u001b[0m             \u001b[0mbll\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbl\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m             \u001b[0mloss\u001b[0m\u001b[1;33m+=\u001b[0m\u001b[0mloss_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfll\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mloss_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbll\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m         \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     24\u001b[0m         \u001b[0mopt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m         \u001b[0mopt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\kangb\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m    183\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    184\u001b[0m         \"\"\"\n\u001b[1;32m--> 185\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    186\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    187\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\kangb\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m    125\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m    126\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 127\u001b[1;33m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    128\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "opt = optim.Adam(model.parameters(),lr = config[\"lr\"])\n",
    "loss_function = torch.nn.NLLLoss()\n",
    "for epoch in range(config[\"epochs\"]):\n",
    "    total_loss = 0\n",
    "    print(f\"Epoch: {epoch+1}\")\n",
    "    for batch in tqdm(data_loader):\n",
    "        total_loss = 0\n",
    "        w , c , i = batch\n",
    "        w = w.to(device)\n",
    "        c = c.to(device)\n",
    "        f, b = model(w,c)\n",
    "        f, b = f[-1], b[-1]\n",
    "        k_max=torch.max(i)\n",
    "        loss = 0\n",
    "        for k in range(1,k_max):\n",
    "            fpass=f[:,k-1,:]\n",
    "            bpass=b[:,k-1,:]\n",
    "            fl = model.ln(fpass).squeeze()\n",
    "            bl = model.ln(bpass).squeeze()\n",
    "            fll = torch.nn.functional.log_softmax(fl,dim=1).squeeze()\n",
    "            bll = torch.nn.functional.log_softmax(bl,dim=1).squeeze()\n",
    "            loss+=loss_function(fll,w[:,k])+loss_function(bll,w[:,k])\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "        model.zero_grad()\n",
    "        total_loss += loss.detach().item()\n",
    "    model.save_model(model_save_path)\n",
    "    print('total_loss:',total_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9120a89",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
